package org.geektimes.web.mvc;

import org.apache.commons.lang.StringUtils;
import org.geektimes.web.mvc.annotation.Controller;
import org.geektimes.web.mvc.annotation.ResponseBody;
import org.geektimes.web.mvc.context.ComponentContext;
import org.geektimes.web.mvc.scanner.ClasspathPackageScanner;

import javax.servlet.RequestDispatcher;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.Path;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;

import static java.util.Arrays.asList;
import static org.apache.commons.lang.StringUtils.substringAfter;

/**
 * 所有请求的入口，用于做路径映射
 */
public class FrontControllerServlet extends HttpServlet {

    /**
     * Name of the servlet init parameter that holds the package to scan for controllers.
     */
    private static final String SCAN_PATH_PARAM = "scanpath";

    /**
     * Package path scanned for controllers (value of the {@code scanpath} init parameter).
     */
    private String scanPath;

    /**
     * Cache of request path -> controller instance that handles it.
     */
    private final Map<String, Object> controllersMapping = new HashMap<>();

    /**
     * Cache of request path -> {@link HandlerMethodInfo} describing the handler method.
     */
    private final Map<String, HandlerMethodInfo> handleMethodInfoMapping = new HashMap<>();

    private final ComponentContext context = new ComponentContext();

    /**
     * Initializes the servlet.
     *
     * <p>This method stores the servlet config and scans {@link #scanPath} for
     * {@link Controller}-annotated classes. It then builds the path mappings and finally
     * initializes the component context.
     *
     * @param servletConfig servlet configuration supplied by the container
     * @throws ServletException if scanning, controller instantiation or context
     *                          initialization fails. The cause is preserved, so the container
     *                          does not put a half-initialized servlet into service.
     */
    @Override
    public void init(ServletConfig servletConfig) throws ServletException {
        // GenericServlet.init(ServletConfig) stores the config; skipping it leaves
        // getServletConfig() null and makes getServletContext() fail.
        super.init(servletConfig);
        scanPath = servletConfig.getInitParameter(SCAN_PATH_PARAM);
        try {
            initHandleMethods();
            // Initialize the component container (controllers were registered above).
            context.init(servletConfig.getServletContext());
        } catch (ServletException e) {
            throw e;
        } catch (Exception e) {
            throw new ServletException(
                    "Failed to initialize FrontControllerServlet (scanpath=" + scanPath + ")", e);
        }
    }

    /**
     * Scans every class under {@link #scanPath} and registers those annotated with
     * {@link Controller}.
     *
     * <p>Each controller is instantiated once and registered in the component context.
     * Each of its public methods annotated with {@link Path} is mapped to the concatenation
     * of the class-level and method-level {@code @Path} values.
     *
     * @throws Exception if a class cannot be loaded or a controller cannot be instantiated
     */
    private void initHandleMethods() throws Exception {
        ClasspathPackageScanner scanner = new ClasspathPackageScanner(scanPath);
        // Fully qualified names of all classes under the scanned package
        List<String> classNames = scanner.getFullyQualifiedClassNameList();
        for (String className : classNames) {
            Class<?> controllerClass = Class.forName(className);
            // Only @Controller classes take part in request mapping
            if (controllerClass.getAnnotation(Controller.class) == null) {
                continue;
            }
            // A controller without a class-level @Path maps its methods from the root
            Path pathFromClass = controllerClass.getAnnotation(Path.class);
            String classPath = pathFromClass != null ? pathFromClass.value() : "";

            Object controller = controllerClass.getDeclaredConstructor().newInstance();
            context.setBean(controller.getClass().getName(), controller);

            for (Method method : controllerClass.getMethods()) {
                Path pathFromMethod = method.getAnnotation(Path.class);
                if (pathFromMethod == null) {
                    continue;
                }
                String requestPath = classPath + pathFromMethod.value();
                Set<String> supportedHttpMethods = findSupportedHttpMethods(method);
                handleMethodInfoMapping.put(requestPath,
                        new HandlerMethodInfo(requestPath, method, supportedHttpMethods));
                controllersMapping.put(requestPath, controller);
            }
        }
    }

    /**
     * Collects the HTTP methods declared on a handler method.
     *
     * <p>A method declares an HTTP method through any annotation that is itself
     * meta-annotated with {@link HttpMethod}, for example {@code @GET}.
     *
     * @param method handler method
     * @return the declared HTTP methods in declaration order. When none are declared, it
     *         returns GET, POST, PUT, DELETE, HEAD and OPTIONS.
     */
    private Set<String> findSupportedHttpMethods(Method method) {
        Set<String> supportedHttpMethods = new LinkedHashSet<>();
        for (Annotation annotationFromMethod : method.getAnnotations()) {
            HttpMethod httpMethod = annotationFromMethod.annotationType().getAnnotation(HttpMethod.class);
            if (httpMethod != null) {
                supportedHttpMethods.add(httpMethod.value());
            }
        }

        if (supportedHttpMethods.isEmpty()) {
            supportedHttpMethods.addAll(asList(HttpMethod.GET, HttpMethod.POST,
                    HttpMethod.PUT, HttpMethod.DELETE, HttpMethod.HEAD, HttpMethod.OPTIONS));
        }

        return supportedHttpMethods;
    }

    /**
     * Dispatches a request to the handler method mapped to its path.
     *
     * <p>The response depends on how the request matches:
     * <ul>
     *   <li>Unmapped paths get 404.</li>
     *   <li>Unsupported HTTP methods get 405 with an {@code Allow} header.</li>
     *   <li>For handlers without {@link ResponseBody}, the returned String is treated as a
     *       view path and the request is forwarded to it.</li>
     * </ul>
     *
     * @param request  current request
     * @param response current response
     * @throws ServletException if the handler cannot be invoked, fails with a non-I/O error,
     *                          or does not return a usable view path
     * @throws IOException      if the handler or the forward fails with an I/O error
     */
    @Override
    public void service(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        // e.g. requestURI = /a/hello/world, contextPath = /a (or "" for the root context)
        String requestURI = request.getRequestURI();
        String contextPath = StringUtils.replace(request.getContextPath(), "//", "/");
        // Mapping key: the sub-path after the context path, e.g. /hello/world
        String requestMappingPath = substringAfter(requestURI, contextPath);

        Object controller = controllersMapping.get(requestMappingPath);
        HandlerMethodInfo handlerMethodInfo = handleMethodInfoMapping.get(requestMappingPath);
        if (controller == null || handlerMethodInfo == null) {
            response.sendError(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        if (!handlerMethodInfo.getSupportedHttpMethods().contains(request.getMethod())) {
            // A 405 response must advertise the supported methods in an Allow header
            response.setHeader("Allow", String.join(", ", handlerMethodInfo.getSupportedHttpMethods()));
            response.setStatus(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
            return;
        }

        Method handlerMethod = handlerMethodInfo.getHandlerMethod();
        if (handlerMethod.getAnnotation(ResponseBody.class) != null) {
            // TODO: support @ResponseBody handlers (write the return value to the response)
            return;
        }

        // Without @ResponseBody the handler returns a view path to forward to
        Object result = invokeHandler(handlerMethod, controller, request, response);
        if (!(result instanceof String)) {
            throw new ServletException("Handler method " + handlerMethod
                    + " must return a view path String but returned: " + result);
        }
        String viewPath = (String) result;
        if (!viewPath.startsWith("/")) {
            viewPath = "/" + viewPath;
        }
        ServletContext servletContext = request.getServletContext();
        RequestDispatcher requestDispatcher = servletContext.getRequestDispatcher(viewPath);
        if (requestDispatcher == null) {
            throw new ServletException("No RequestDispatcher available for view path " + viewPath);
        }
        requestDispatcher.forward(request, response);
    }

    /**
     * Invokes a handler method as {@code handler(request, response)}.
     *
     * <p>The handler's own {@link IOException}, {@link ServletException} and {@link Error}
     * are rethrown unchanged. Any other failure is wrapped in a {@link ServletException}
     * with the cause preserved.
     *
     * @return the handler's return value
     */
    private Object invokeHandler(Method handlerMethod, Object controller,
                                 HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        try {
            return handlerMethod.invoke(controller, request, response);
        } catch (InvocationTargetException e) {
            Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            if (cause instanceof ServletException) {
                throw (ServletException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new ServletException("Handler method " + handlerMethod + " failed", cause);
        } catch (IllegalAccessException | IllegalArgumentException e) {
            // e.g. the handler is not accessible or does not take (request, response)
            throw new ServletException("Cannot invoke handler method " + handlerMethod, e);
        }
    }

}
